背景介绍
CGAN(Conditional Generative Adversarial Networks, 条件生成式对抗网络):于2014年提出,引入标签变量,可以通过控制其标签变量的值,产生不同类别的图像,其网络结构和GAN基本类似,只是多了一些条件变量的处理。
CGAN特点
生成器的输入有两个,一个是随机数,一个是标签数据的one-hot编码形式,利用Concatenate层将两个输入融合。
判别器的输入也有两个,一个是输入图像,一个是标签数据的one-hot编码形式,首先利用Flatten将输入图像转化维一维向量,然后利用Concatenate层将两个输入融合。
CGAN图像分析
TensorFlow2.0实现
1 | import os |
模型运行结果
小技巧
- 图像输入可以先将其归一化到0-1之间或者-1-1之间,因为网络的参数一般都比较小,所以归一化后计算方便,收敛较快。
- 注意其中的一些维度变换和numpy,tensorflow常用操作,否则在阅读代码时可能会产生一些困难。
- 可以设置一些权重的保存方式,学习率的下降方式和早停方式。
- CGAN对于网络结构,优化器参数,网络层的一些超参数都是非常敏感的,效果不好不容易发现原因,这可能需要较多的工程实践经验。
- 先创建判别器,然后进行compile,这样判别器就固定了,然后创建生成器时,不要训练判别器,需要将判别器的trainable改成False,此时不会影响之前固定的判别器,这个可以通过模型的_collection_collected_trainable_weights属性查看,如果该属性为空,则模型不训练,否则模型可以训练,compile之后,该属性固定,无论后面如何修改trainable,只要不重新compile,都不影响训练。
- CGAN的网络结构和GAN基本相同,因此也只适合小图像的生成,其中生成器使用Concatenate实现标签数据和输入随机数的结合,判别器使用Concatenate实现标签数据和输入图像的结合,一定要注意首先要对图像进行Flatten处理,否则会出错。
CGAN小结
CGGAN是一种简单的生成式对抗网络,从上图可以看出CGAN模型的参数量只有2M,和普通的GAN网络差不多,通过CGAN可以实现指定类别的图像生成,不再是完全的随机数产生,因此对于实际的工程应用是有意义的,值得小伙伴们学习。